#!/usr/bin/env python3
"""
Phase 2: Baseline Fragility Documentation
Tables 2-5: Jackknife, leave-one-out by region, progressive exclusion, CCA dummies.
"""

import sys
import pandas as pd
import numpy as np
from pathlib import Path

PROJECT_DIR = Path("/mnt/c/demographics_capital_flows")
sys.path.insert(0, str(PROJECT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Directory layout of the cca_tipping sub-project: the input panel is read from
# data/processed and every table produced by this script goes to output/tables.
CCA_DIR = PROJECT_DIR.joinpath("cca_tipping")
PROCESSED_DIR = CCA_DIR.joinpath("data", "processed")
TABLE_DIR = CCA_DIR.joinpath("output", "tables")
# Create the output folder (and any missing parents) up front; an existing
# folder is not an error.
TABLE_DIR.mkdir(exist_ok=True, parents=True)

# ── Country Sets ───────────────────────────────────────────────────────────
# ISO3 codes of the country groups that the exclusion exercises drop or flag.
# CCA_NON_COMMODITY is a subset of CCA_ALL; the five remaining CCA members
# (AZE, KAZ, RUS, TKM, UZB) are the ones the tables label "commodity".
CCA_NON_COMMODITY = {'ARM', 'BLR', 'GEO', 'KGZ', 'MDA', 'MNG', 'TJK', 'UKR'}
CCA_ALL = CCA_NON_COMMODITY | {'AZE', 'KAZ', 'RUS', 'TKM', 'UZB'}
BALTIC = {'EST', 'LTU', 'LVA'}
CEE = {
    'ALB', 'BGR', 'BIH', 'CZE', 'HRV', 'HUN', 'MKD',
    'MNE', 'POL', 'ROU', 'SRB', 'SVK', 'SVN',
}

# ── Helpers ────────────────────────────────────────────────────────────────
def star(p):
    """Return the significance marker for p-value *p*.

    ``"***"`` for p < 0.01, ``"**"`` for p < 0.05, ``"*"`` for p < 0.10 and an
    empty string otherwise (all comparisons are strict, so NaN yields ``""``).
    """
    # Thresholds are checked from strictest to loosest; the first hit wins.
    for cutoff, marker in ((0.01, "***"), (0.05, "**"), (0.10, "*")):
        if p < cutoff:
            return marker
    return ""

def run_gls(df, dep_var, indep_vars):
    """Fit ``PanelGLS`` of *dep_var* on *indep_vars* and summarise the fit.

    Rows with a missing value in the dependent variable or in any regressor
    are discarded first.  If fewer than 50 complete rows remain, nothing is
    estimated and ``None`` is returned.

    Args:
        df: DataFrame holding *dep_var*, every column named in *indep_vars*,
            and the ``iso3`` and ``year`` columns handed to the estimator.
        dep_var: Column name of the dependent variable.
        indep_vars: List of regressor column names; their order fixes the
            column order of the design matrix and of the reported statistics.

    Returns:
        ``None`` for an under-sized sample; otherwise a dict with
        ``r_squared``, ``n_obs``, ``n_countries`` (taken from the fitted
        estimator), ``coefficients``, ``std_errors`` and ``p_values`` (each a
        dict keyed by regressor name) and ``ssr``.
    """
    required_cols = [dep_var] + indep_vars
    complete = df.dropna(subset=required_cols).copy()
    if len(complete) < 50:
        return None

    y = complete[dep_var].values
    X = complete[indep_vars].values

    gls = PanelGLS()
    gls.fit(y, X, complete['iso3'].values, complete['year'].values)

    # Residuals are formed here as y - X @ beta on the untransformed data.
    # NOTE(review): whether PanelGLS applies further transformations internally
    # is not visible from this file.
    resid = y - X @ gls.beta
    squared_resid = resid ** 2

    def by_regressor(values):
        # Pair each estimator output with its regressor name.
        return {name: value for name, value in zip(indep_vars, values)}

    summary = {}
    summary['r_squared'] = gls.r_squared
    summary['n_obs'] = gls.n_obs
    summary['n_countries'] = gls.n_countries
    summary['coefficients'] = by_regressor(gls.beta)
    summary['std_errors'] = by_regressor(gls.se)
    summary['p_values'] = by_regressor(gls.pvalues)
    summary['ssr'] = float(np.sum(squared_resid))
    return summary

# ── Load Data ──────────────────────────────────────────────────────────────
print("=" * 70)
print("PHASE 2: BASELINE FRAGILITY DOCUMENTATION")
print("=" * 70)

# Estimation window: rows with a non-missing ca_gdp, years 1986-2024 inclusive.
panel = pd.read_csv(PROCESSED_DIR / "cca_panel.csv")
est = panel[
    (panel['ca_gdp'].notna()) &
    (panel['year'] >= 1986) &
    (panel['year'] <= 2024)
].copy()

# Regressors of the baseline specification: three demographic terms followed by
# the control set.  The order here fixes the column order in every regression.
demo_vars = ['Z_1', 'Z_2', 'Z_3']
baseline_controls = ['fiscal_bal_gdp', 'kaopen', 'expected_growth',
                     'nfa_gdp_lag', 'log_rel_opw', 'health_exp_gdp']
base_vars = demo_vars + baseline_controls

# Every table below is estimated on (subsets of) this complete-case sample.
base_sample = est.dropna(subset=base_vars + ['ca_gdp']).copy()
print(f"Base sample: {base_sample['iso3'].nunique()} countries, {len(base_sample):,} obs")

# ── Baseline ───────────────────────────────────────────────────────────────
r_full = run_gls(base_sample, 'ca_gdp', base_vars)
if r_full is None:
    # run_gls declines to estimate on fewer than 50 complete rows.  The full
    # sample result is the reference row of the tables that follow, so stop
    # with a clear message instead of failing later on a None subscript.
    raise SystemExit(
        f"Baseline regression not estimated: base sample has only "
        f"{len(base_sample):,} complete observations."
    )

z1_coef = r_full['coefficients']['Z_1']
z1_pval = r_full['p_values']['Z_1']
print(f"\nFull sample: Z₁={z1_coef:.2f}{star(z1_pval)} "
      f"(p={z1_pval:.4f}), R²={r_full['r_squared']:.4f}")

# ═══════════════════════════════════════════════════════════════════════════
# TABLE 2: JACKKNIFE — DROP EACH CCA COUNTRY ONE AT A TIME
# ═══════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("TABLE 2: CCA JACKKNIFE (DROP ONE AT A TIME)")
print("=" * 70)

# Only CCA members present in the estimation sample can be dropped; sorting
# keeps the row order of the table deterministic.
cca_in_sample = sorted(CCA_ALL & set(base_sample['iso3'].unique()))
print(f"CCA countries in sample: {cca_in_sample}")

jk_rows = []
for iso in cca_in_sample:
    # One mask serves both the re-estimation sample and the dropped-row count.
    is_dropped = base_sample['iso3'] == iso
    sub = base_sample[~is_dropped].copy()
    r = run_gls(sub, 'ca_gdp', base_vars)
    if r is None:
        # run_gls returned no estimate for this subsample; leave it out.
        continue
    row = {
        'dropped_country': iso,
        'type': 'non_commodity' if iso in CCA_NON_COMMODITY else 'commodity',
        'n_countries': r['n_countries'],
        'n_obs_dropped': int(is_dropped.sum()),
        'n_obs': r['n_obs'],
        'r_squared': r['r_squared'],
    }
    for v in demo_vars:
        row[f'{v}_coef'] = r['coefficients'][v]
        row[f'{v}_se'] = r['std_errors'][v]
        row[f'{v}_pval'] = r['p_values'][v]
    jk_rows.append(row)
    print(f"  Drop {iso:3s} ({row['type']:<13s}): Z₁={row['Z_1_coef']:7.2f}{star(row['Z_1_pval'])} "
          f"(p={row['Z_1_pval']:.4f}), N={r['n_countries']}")

jk_df = pd.DataFrame(jk_rows)
jk_df.to_csv(TABLE_DIR / "table2_cca_jackknife.csv", index=False)

# Markdown rendering: the full-sample estimate comes first as the reference
# row, followed by one row per dropped country.
md = ["# Table 2: CCA Jackknife — Drop Each Country One at a Time\n"]
md.append("| Dropped | Type | N_c | Z₁ | SE | p-val | Z₂ | p-val | R² |")
md.append("|---|---|---|---|---|---|---|---|---|")
md.append(f"| *Full sample* | — | {r_full['n_countries']} | "
          f"{r_full['coefficients']['Z_1']:.2f}{star(r_full['p_values']['Z_1'])} | "
          f"{r_full['std_errors']['Z_1']:.2f} | {r_full['p_values']['Z_1']:.4f} | "
          f"{r_full['coefficients']['Z_2']:.3f}{star(r_full['p_values']['Z_2'])} | "
          f"{r_full['p_values']['Z_2']:.4f} | {r_full['r_squared']:.4f} |")
for _, row in jk_df.iterrows():
    md.append(f"| {row['dropped_country']} | {row['type']} | {row['n_countries']} | "
              f"{row['Z_1_coef']:.2f}{star(row['Z_1_pval'])} | {row['Z_1_se']:.2f} | "
              f"{row['Z_1_pval']:.4f} | {row['Z_2_coef']:.3f}{star(row['Z_2_pval'])} | "
              f"{row['Z_2_pval']:.4f} | {row['r_squared']:.4f} |")
# The table contains non-ASCII characters (Z₁, R², —); write UTF-8 explicitly
# so the output does not depend on the locale's default encoding.
(TABLE_DIR / "table2_cca_jackknife.md").write_text("\n".join(md), encoding="utf-8")

# ═══════════════════════════════════════════════════════════════════════════
# TABLE 3: LEAVE-ONE-REGION-OUT
# ═══════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("TABLE 3: LEAVE-ONE-REGION-OUT")
print("=" * 70)

# Each entry is a group of countries removed as a block before re-estimating.
# The groups overlap on purpose (e.g. CCA_all contains both CCA sub-groups).
region_sets = {
    'CCA_commodity': CCA_ALL - CCA_NON_COMMODITY,
    'CCA_non_commodity': CCA_NON_COMMODITY,
    'CCA_all': CCA_ALL,
    'Baltic': BALTIC,
    'CEE': CEE,
    'All_transition': CCA_ALL | BALTIC | CEE,
}

# Countries actually present in the estimation sample (loop-invariant).
sample_countries = set(base_sample['iso3'].unique())

region_rows = []
for region_name, country_set in region_sets.items():
    in_sample = country_set & sample_countries
    if not in_sample:
        # Nothing from this group is in the sample, so dropping it is a no-op.
        continue
    sub = base_sample[~base_sample['iso3'].isin(country_set)].copy()
    r = run_gls(sub, 'ca_gdp', base_vars)
    if r is None:
        # run_gls returned no estimate for this subsample; leave it out.
        continue
    row = {
        'dropped_region': region_name,
        'n_dropped': len(in_sample),
        'n_countries': r['n_countries'],
        'n_obs': r['n_obs'],
        'r_squared': r['r_squared'],
    }
    for v in demo_vars:
        row[f'{v}_coef'] = r['coefficients'][v]
        row[f'{v}_se'] = r['std_errors'][v]
        row[f'{v}_pval'] = r['p_values'][v]
    region_rows.append(row)
    print(f"  Drop {region_name:<20s} ({row['n_dropped']:2d}c): Z₁={row['Z_1_coef']:7.2f}{star(row['Z_1_pval'])} "
          f"(p={row['Z_1_pval']:.4f}), R²={row['r_squared']:.4f}")

region_df = pd.DataFrame(region_rows)
region_df.to_csv(TABLE_DIR / "table3_leave_region_out.csv", index=False)

# Markdown rendering: full-sample reference row first, then one row per group.
md = ["# Table 3: Leave-One-Region-Out\n"]
md.append("| Dropped Region | N dropped | N_c remaining | Z₁ | SE | p-val | R² |")
md.append("|---|---|---|---|---|---|---|")
md.append(f"| *Full sample* | 0 | {r_full['n_countries']} | "
          f"{r_full['coefficients']['Z_1']:.2f}{star(r_full['p_values']['Z_1'])} | "
          f"{r_full['std_errors']['Z_1']:.2f} | {r_full['p_values']['Z_1']:.4f} | "
          f"{r_full['r_squared']:.4f} |")
for _, row in region_df.iterrows():
    md.append(f"| {row['dropped_region']} | {row['n_dropped']:.0f} | {row['n_countries']} | "
              f"{row['Z_1_coef']:.2f}{star(row['Z_1_pval'])} | {row['Z_1_se']:.2f} | "
              f"{row['Z_1_pval']:.4f} | {row['r_squared']:.4f} |")
# The table contains non-ASCII characters (Z₁, R²); write UTF-8 explicitly so
# the output does not depend on the locale's default encoding.
(TABLE_DIR / "table3_leave_region_out.md").write_text("\n".join(md), encoding="utf-8")

# ═══════════════════════════════════════════════════════════════════════════
# TABLE 4: PROGRESSIVE EXCLUSION
# ═══════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("TABLE 4: PROGRESSIVE EXCLUSION")
print("=" * 70)

# (label, countries to exclude) pairs.  The count in each label is the size of
# the exclusion set itself, derived from the sets so the labels cannot drift
# if a group definition changes; the number of those countries actually present
# in the sample is reported separately in the n_excluded column.
cca_commodity = CCA_ALL - CCA_NON_COMMODITY
cca_and_baltic = CCA_ALL | BALTIC
all_transition = CCA_ALL | BALTIC | CEE
progressive_steps = [
    ("Full sample", set()),
    (f"Excl CCA commodity ({len(cca_commodity)})", cca_commodity),
    (f"Excl CCA non-commodity ({len(CCA_NON_COMMODITY)})", CCA_NON_COMMODITY),
    (f"Excl all CCA ({len(CCA_ALL)})", CCA_ALL),
    (f"Excl CCA + Baltic ({len(cca_and_baltic)})", cca_and_baltic),
    (f"Excl all transition ({len(all_transition)})", all_transition),
]

# Countries actually present in the estimation sample (loop-invariant).
sample_countries = set(base_sample['iso3'].unique())

prog_rows = []
for step_name, excl_set in progressive_steps:
    sub = base_sample[~base_sample['iso3'].isin(excl_set)].copy()
    r = run_gls(sub, 'ca_gdp', base_vars)
    if r is None:
        # run_gls returned no estimate for this subsample; leave it out.
        continue
    row = {
        'step': step_name,
        'n_excluded': len(excl_set & sample_countries),
        'n_countries': r['n_countries'],
        'n_obs': r['n_obs'],
        'r_squared': r['r_squared'],
    }
    for v in demo_vars:
        row[f'{v}_coef'] = r['coefficients'][v]
        row[f'{v}_se'] = r['std_errors'][v]
        row[f'{v}_pval'] = r['p_values'][v]
    prog_rows.append(row)
    print(f"  {step_name:<35s}: Z₁={row['Z_1_coef']:7.2f}{star(row['Z_1_pval'])} "
          f"(p={row['Z_1_pval']:.4f}), N_c={r['n_countries']}")

prog_df = pd.DataFrame(prog_rows)
prog_df.to_csv(TABLE_DIR / "table4_progressive_exclusion.csv", index=False)

# Markdown rendering: one row per exclusion step (the first step is the full
# sample, so no separate reference row is needed).
md = ["# Table 4: Progressive Exclusion\n"]
md.append("| Step | N excluded | N_c | Z₁ | SE | p-val | Z₂ | p-val | R² |")
md.append("|---|---|---|---|---|---|---|---|---|")
for _, row in prog_df.iterrows():
    md.append(f"| {row['step']} | {row['n_excluded']:.0f} | {row['n_countries']} | "
              f"{row['Z_1_coef']:.2f}{star(row['Z_1_pval'])} | {row['Z_1_se']:.2f} | "
              f"{row['Z_1_pval']:.4f} | {row['Z_2_coef']:.3f}{star(row['Z_2_pval'])} | "
              f"{row['Z_2_pval']:.4f} | {row['r_squared']:.4f} |")
# The table contains non-ASCII characters (Z₁, Z₂, R²); write UTF-8 explicitly
# so the output does not depend on the locale's default encoding.
(TABLE_DIR / "table4_progressive_exclusion.md").write_text("\n".join(md), encoding="utf-8")

# ═══════════════════════════════════════════════════════════════════════════
# TABLE 5: CCA DUMMIES AS CONTROLS
# ═══════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("TABLE 5: CCA DUMMIES AS CONTROLS")
print("=" * 70)

# (label, regressor list) pairs: the baseline plus variants that add group
# dummies.  The is_* columns are not created in this script; they must already
# be present in the loaded panel, otherwise the regression call raises KeyError.
dummy_specs = [
    ("Baseline", base_vars),
    ("+ CCA dummy", base_vars + ['is_cca']),
    ("+ CCA non-commodity dummy", base_vars + ['is_cca_non_commodity']),
    ("+ Transition dummy", base_vars + ['is_transition']),
    ("+ CCA + CCA_nc dummies", base_vars + ['is_cca', 'is_cca_non_commodity']),
    ("+ CCA + Baltic + CEE dummies", base_vars + ['is_cca', 'is_baltic', 'is_cee']),
    ("+ All transition dummies", base_vars + ['is_cca', 'is_cca_non_commodity',
                                               'is_baltic', 'is_cee']),
]

dummy_rows = []
for spec_name, spec_vars in dummy_specs:
    r = run_gls(base_sample, 'ca_gdp', spec_vars)
    if r is None:
        # run_gls returned no estimate for this specification; leave it out.
        continue
    # Dummies included in this particular specification, in spec order.
    spec_dummies = [v for v in spec_vars if v.startswith('is_')]
    row = {
        'specification': spec_name,
        'n_countries': r['n_countries'],
        'n_obs': r['n_obs'],
        'r_squared': r['r_squared'],
    }
    for v in demo_vars + spec_dummies:
        row[f'{v}_coef'] = r['coefficients'].get(v, np.nan)
        row[f'{v}_se'] = r['std_errors'].get(v, np.nan)
        row[f'{v}_pval'] = r['p_values'].get(v, np.nan)
    dummy_rows.append(row)
    z1c = r['coefficients']['Z_1']
    z1p = r['p_values']['Z_1']
    extra_parts = []
    for dv in spec_dummies:
        dc = r['coefficients'][dv]
        dp = r['p_values'][dv]
        extra_parts.append(f", {dv}={dc:.2f}{star(dp)}")
    extra = "".join(extra_parts)
    print(f"  {spec_name:<35s}: Z₁={z1c:7.2f}{star(z1p)} (p={z1p:.4f}){extra}")

# Specifications that omit a dummy get NaN in that dummy's columns.
dummy_df = pd.DataFrame(dummy_rows)
dummy_df.to_csv(TABLE_DIR / "table5_cca_dummies.csv", index=False)

# Markdown rendering: one row per specification; the last column lists the
# dummies estimated in that specification, in this fixed display order.
dummy_display_order = ['is_cca', 'is_cca_non_commodity', 'is_transition',
                       'is_baltic', 'is_cee']
md = ["# Table 5: CCA Dummies as Controls\n"]
md.append("| Specification | Z₁ | SE | p-val | R² | Dummy coefficients |")
md.append("|---|---|---|---|---|---|")
for _, row in dummy_df.iterrows():
    shown = []
    for dv in dummy_display_order:
        dc = row.get(f'{dv}_coef', np.nan)
        if pd.notna(dc):
            # Column exists and was estimated for this specification.
            dp = row[f'{dv}_pval']
            shown.append(f"{dv}={dc:.2f}{star(dp)}")
    dummies_str = " ".join(shown)
    md.append(f"| {row['specification']} | "
              f"{row['Z_1_coef']:.2f}{star(row['Z_1_pval'])} | {row['Z_1_se']:.2f} | "
              f"{row['Z_1_pval']:.4f} | {row['r_squared']:.4f} | {dummies_str} |")
# The table contains non-ASCII characters (Z₁, R²); write UTF-8 explicitly so
# the output does not depend on the locale's default encoding.
(TABLE_DIR / "table5_cca_dummies.md").write_text("\n".join(md), encoding="utf-8")

print(f"\nAll Phase 2 tables saved to {TABLE_DIR}")
print("Phase 2 complete.")
